// SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later

#pragma once

#include <algorithm>
#include <array>
#include <bit>
#include <cstring>
#include <span>
#include <vector>
#include <boost/container/static_vector.hpp>
#include "common/assert.h"
#include "common/types.h"
#include "shader_recompiler/backend/bindings.h"
#include "shader_recompiler/frontend/copy_shader.h"
#include "shader_recompiler/frontend/tessellation.h"
#include "shader_recompiler/ir/attribute.h"
#include "shader_recompiler/ir/passes/srt.h"
#include "shader_recompiler/ir/reg.h"
#include "shader_recompiler/ir/type.h"
#include "shader_recompiler/params.h"
#include "shader_recompiler/resource.h"
#include "shader_recompiler/runtime_info.h"

// Forward declaration only: this header refers to Archive solely by reference
// (Info::Serialize / Info::Deserialize), so the full definition is not needed here.
namespace Serialization {
struct Archive;
}

namespace Shader {

/// Interpolation qualifier attached to a shader parameter; stored as the
/// primary/auxiliary pair in Info::Interpolation.
/// Enumerator values are spelled out explicitly and match their declaration order.
enum class Qualifier : u8 {
    None = 0,
    Smooth = 1,
    NoPerspective = 2,
    PerVertex = 3,
    Flat = 4,
    Centroid = 5,
    Sample = 6,
};

/**
 * Contains general information generated by the shader recompiler for an input program.
 */
struct InfoPersistent {
    BufferResourceList buffers;
    ImageResourceList images;
    SamplerResourceList samplers;
    FMaskResourceList fmasks;

    /// Bitmask of user data registers, one bit per IR::ScalarReg index.
    struct UserDataMask {
        /// Marks `reg` as used.
        void Set(IR::ScalarReg reg) noexcept {
            // Unsigned literal: shifting a signed 1 into bit 31 is UB before C++20.
            mask |= 1U << static_cast<u32>(reg);
        }

        /// Returns the compacted slot of `reg`, i.e. how many marked registers precede it.
        u32 Index(IR::ScalarReg reg) const noexcept {
            const u32 reg_mask = (1U << static_cast<u32>(reg)) - 1;
            return std::popcount(mask & reg_mask);
        }

        /// Returns the number of marked registers.
        u32 NumRegs() const noexcept {
            return std::popcount(mask);
        }

        u32 mask{};
    };
    UserDataMask ud_mask{};
    u32 fetch_shader_sgpr_base{};

    u64 pgm_hash{};

    // -1 means the tessellation constants V# has not been located yet.
    s32 tess_consts_dword_offset = -1;
    IR::ScalarReg tess_consts_ptr_base = IR::ScalarReg::Max;
    Stage stage{};
    LogicalStage l_stage{};

    u8 mrt_mask{};
    bool has_fetch_shader{};
    bool has_bitwise_xor{};
    bool uses_dma{};

    InfoPersistent() = default;
    // Initializers are listed in member declaration order (pgm_hash precedes stage/l_stage).
    InfoPersistent(Stage stage_, LogicalStage l_stage_, u64 pgm_hash_)
        : pgm_hash{pgm_hash_}, stage{stage_}, l_stage{l_stage_} {}
};

/**
 * Shader program information: the fields of InfoPersistent plus state gathered while
 * recompiling, along with helpers for reading user data and the flattened SRT buffer.
 */
struct Info : InfoPersistent {
    /// Per-attribute component bitmask (bit N corresponds to component N).
    struct AttributeFlags {
        /// Returns whether component `comp` of `attrib` is marked.
        bool Get(IR::Attribute attrib, u32 comp = 0) const {
            return flags[Index(attrib)] & (1 << comp);
        }

        /// Returns whether any component of `attrib` is marked.
        bool GetAny(IR::Attribute attrib) const {
            return flags[Index(attrib)];
        }

        /// Marks component `comp` of `attrib`.
        void Set(IR::Attribute attrib, u32 comp = 0) {
            flags[Index(attrib)] |= (1 << comp);
        }

        /// Always reports 4 components; `attrib` is currently ignored.
        u32 NumComponents([[maybe_unused]] IR::Attribute attrib) const {
            return 4;
        }

        static size_t Index(IR::Attribute attrib) {
            return static_cast<size_t>(attrib);
        }

        std::array<u8, IR::NumAttributes> flags{};
    };

    enum class ReadConstType {
        None = 0,
        Immediate = 1 << 0,
        Dynamic = 1 << 1,
    };

    struct Interpolation {
        Qualifier primary;
        Qualifier auxiliary;
    };

    std::span<const u32> user_data;
    std::vector<u32> flattened_ud_buf;
    PersistentSrtInfo srt_info;

    AttributeFlags loads{};
    AttributeFlags stores{};

    ReadConstType readconst_types{};
    CopyShaderData gs_copy_data;
    u32 uses_patches{};

    VAddr pgm_base{};
    bool has_storage_images{};
    bool has_discard{};
    bool has_image_gather{};
    bool has_image_query{};
    bool uses_buffer_atomic_float_min_max{};
    bool uses_image_atomic_float_min_max{};
    bool uses_lane_id{};
    bool uses_group_quad{};
    bool uses_group_ballot{};
    IR::Type shared_types{};
    bool uses_fp16{};
    bool uses_fp64{};
    bool uses_pack_10_11_11{};
    bool uses_unpack_10_11_11{};
    bool uses_buffer_int64_atomics{};
    bool uses_shared_int64_atomics{};
    bool stores_tess_level_outer{};
    bool stores_tess_level_inner{};
    bool translation_failed{};

    std::array<Interpolation, IR::NumParams> fs_interpolation{};

    Info() = default;
    // Initializers follow member declaration order: user_data is declared before pgm_base.
    Info(Stage stage_, LogicalStage l_stage_, ShaderParams params)
        : InfoPersistent(stage_, l_stage_, params.hash), user_data{params.user_data},
          pgm_base{params.Base()} {}

    /// Reads a T starting at dword `sharp_idx` of the flattened user data buffer.
    /// Copies the bytes rather than dereferencing a reinterpret_cast pointer, which
    /// would violate strict aliasing.
    template <typename T>
    T ReadUdSharp(u32 sharp_idx) const noexcept {
        T sharp;
        std::memcpy(&sharp, &flattened_ud_buf[sharp_idx], sizeof(T));
        return sharp;
    }

    /// Reads a T located `dword_offset` dwords past a base address.
    /// When `ptr_index == IR::NumScalarRegs`, the base is the user data registers themselves.
    /// Otherwise the base is the 64-bit pointer stored at user_data[ptr_index], masked to its
    /// low 48 bits.
    template <typename T>
    T ReadUdReg(u32 ptr_index, u32 dword_offset) const noexcept {
        T data;
        const u32* base = user_data.data();
        if (ptr_index != IR::NumScalarRegs) {
            std::memcpy(&base, &user_data[ptr_index], sizeof(base));
            base = reinterpret_cast<const u32*>(VAddr(base) & 0xFFFFFFFFFFFFULL);
        }
        std::memcpy(&data, base + dword_offset, sizeof(T));
        return data;
    }

    /// Appends each user data register marked in ud_mask, in ascending register order,
    /// to push.ud_regs starting at slot bnd.user_data, advancing bnd.user_data as it goes.
    void PushUd(Backend::Bindings& bnd, PushData& push) const {
        u32 mask = ud_mask.mask;
        while (mask) {
            const u32 index = std::countr_zero(mask);
            ASSERT(bnd.user_data < NUM_USER_DATA_REGS && index < NUM_USER_DATA_REGS);
            mask &= ~(1U << index);
            push.ud_regs[bnd.user_data++] = user_data[index];
        }
    }

    /// Advances the binding counters by the resources and user data registers of this program.
    void AddBindings(Backend::Bindings& bnd) const {
        bnd.buffer += buffers.size();
        bnd.unified += buffers.size() + images.size() + samplers.size();
        bnd.user_data += ud_mask.NumRegs();
    }

    /// Rebuilds flattened_ud_buf. The raw user data is copied to its start, then
    /// srt_info.walker_func (if set) is run over the user data and the flattened buffer.
    void RefreshFlatBuf() {
        ASSERT(user_data.size() <= NUM_USER_DATA_REGS);
        // Never size the buffer below the raw user data, so the copy below stays in bounds.
        flattened_ud_buf.resize(
            std::max<size_t>(srt_info.flattened_bufsize_dw, user_data.size()));
        std::memcpy(flattened_ud_buf.data(), user_data.data(), user_data.size_bytes());
        if (srt_info.walker_func) {
            srt_info.walker_func(user_data.data(), flattened_ud_buf.data());
        }
    }

    /// Copies the tessellation constant buffer out of guest memory. The V# describing it is
    /// read through ReadUdReg using the tracked tess_consts_ptr_base / tess_consts_dword_offset.
    void ReadTessConstantBuffer(TessellationDataConstantBuffer& tess_constants) const {
        ASSERT(tess_consts_dword_offset >= 0); // We've already tracked the V# UD
        const auto buf = ReadUdReg<AmdGpu::Buffer>(static_cast<u32>(tess_consts_ptr_base),
                                                   static_cast<u32>(tess_consts_dword_offset));
        const VAddr tess_constants_addr = buf.base_address;
        std::memcpy(&tess_constants, reinterpret_cast<const void*>(tess_constants_addr),
                    sizeof(tess_constants));
    }

    void Serialize(Serialization::Archive& ar) const;
    bool Deserialize(Serialization::Archive& ar);
};
// Bitwise flag operators for combining ReadConstType values.
DECLARE_ENUM_FLAG_OPERATORS(Info::ReadConstType);

} // namespace Shader
